#include "sk.h"

#include <memory>

namespace net {
    static bool inited = false;

    //winsock库需要特殊的初始化
    static bool WinNetInit() {
        WORD sockVersion = MAKEWORD(2, 2);
        WSADATA wsdata;

        if (WSAStartup(sockVersion, &wsdata) == 0)
            return (inited = true);
        return false;
    }

    namespace sk {
        // Creates a new socket with ::socket(family, type, protocol), starting Winsock
        // first if it has not been initialised yet. If ::socket fails, `sp` holds its
        // failure value (INVALID_SOCKET).
        Socket::Socket(int family, int type, int protocol) {
            if (inited == false) {
                WinNetInit(); // result ignored; ::socket below fails if startup did
            }
            sp = ::socket(family, type, protocol);
        }

        // Closes the underlying handle, but only when isUsable() says it still refers to
        // a live socket; handles for which getsockopt fails are left alone.
        Socket::~Socket() {
            if (!isUsable())
                return;
            closesocket(sp);
        }

        // Connects this socket to ip:port. `ip` is parsed with inet_addr, so it is expected
        // to be a dotted-decimal IPv4 address; the address family is read from the socket.
        // Returns true when ::connect reports success.
        bool Socket::connect(string ip, int port) {
            sockaddr_in remote = {};
            remote.sin_family = getFamily();
            remote.sin_port = htons(port);
            remote.sin_addr.S_un.S_addr = inet_addr(ip.c_str());

            const int status = ::connect(sp, reinterpret_cast<sockaddr *>(&remote), sizeof(remote));
            return status == 0;
        }

        // Binds this socket to ip:port. An empty `ip` binds to INADDR_ANY (the system picks
        // the local address); otherwise `ip` is parsed with inet_addr.
        // Returns true when ::bind reports success.
        bool Socket::bind(string ip, int port) {
            sockaddr_in local = {};
            local.sin_family = getFamily();
            local.sin_port = htons(port);
            local.sin_addr.S_un.S_addr = ip.empty() ? INADDR_ANY : inet_addr(ip.c_str());

            const int status = ::bind(sp, reinterpret_cast<sockaddr *>(&local), sizeof(local));
            return status == 0;
        }

        // Puts the socket into listening mode with a pending-connection backlog of `n`.
        // Returns true when ::listen succeeds.
        bool Socket::listen(int n) {
            const int status = ::listen(sp, n);
            return status == 0;
        }

        // Accepts one pending connection and wraps the returned handle in a Socket.
        // The peer address is written into a local buffer and discarded. If ::accept
        // fails, the returned Socket holds its failure value (INVALID_SOCKET).
        Socket Socket::accept() {
            sockaddr_in peer;
            int peerLen = sizeof(peer);
            return Socket(::accept(sp, reinterpret_cast<sockaddr *>(&peer), &peerLen));
        }

        // Sends `l` bytes starting at `data` with flags 0.
        // NOTE(review): ::send returns an int and it is converted to size_t here, so
        // SOCKET_ERROR (-1) shows up as SIZE_MAX rather than as a negative value.
        size_t Socket::send(const char *data, int l) {
            const int sent = ::send(sp, data, l, 0);
            return static_cast<size_t>(sent);
        }

        // Sends the characters of `s` (s.length() bytes, no terminator).
        size_t Socket::send(string & s) {
            const char *bytes = s.c_str();
            return send(bytes, s.length());
        }

        // Sends every byte stored in `v`.
        size_t Socket::send(vector<char> & v) {
            const char *bytes = v.data();
            return send(bytes, v.size());
        }

        // Sends `l` bytes starting at `data` to ip:port (flags 0). `ip` is parsed with
        // inet_addr and the address family is read from the socket. The int result of
        // ::sendto is converted to size_t, so SOCKET_ERROR surfaces as SIZE_MAX.
        size_t Socket::sendto(string ip, int port, const char *data, int l) {
            sockaddr_in target = {};
            target.sin_family = getFamily();
            target.sin_port = htons(port);
            target.sin_addr.s_addr = inet_addr(ip.c_str());

            const int sent = ::sendto(sp, data, l, 0, reinterpret_cast<sockaddr *>(&target), sizeof(target));
            return static_cast<size_t>(sent);
        }
        
        // Convenience overload: sends the characters of `s` to ip:port.
        size_t Socket::sendto(string ip, int port, string & s) {
            const char *bytes = s.c_str();
            return sendto(ip, port, bytes, s.size());
        }
        
        // Convenience overload: sends every byte of `v` to ip:port.
        size_t Socket::sendto(string ip, int port, vector<char> & v) {
            const char *bytes = v.data();
            return sendto(ip, port, bytes, v.size());
        }

        // Receives at most `buffl` bytes into `buff` (flags 0) and returns the raw ::recv
        // result: bytes read, 0 when the peer has shut down, or SOCKET_ERROR.
        int Socket::recv(char *buff, size_t buffl) {
            const int received = ::recv(sp, buff, static_cast<int>(buffl), 0);
            return received;
        }

        // Receives up to `l` bytes and appends them to the end of `result`.
        // Returns the value of recv(char*, size_t): bytes received, 0 when the peer has
        // shut down, or SOCKET_ERROR (in which case `result` is left unchanged).
        int Socket::recv(vector<char> & result, size_t l) {
            std::unique_ptr<char[]> buff(new char[l]);
            const int rl = recv(buff.get(), l);

            // One range insert instead of a push_back per byte: capacity grows at most once.
            if (rl > 0) {
                result.insert(result.end(), buff.get(), buff.get() + rl);
            }
            return rl;
        }

        // Receives up to `l` bytes and appends them to the end of `result`.
        // Returns the value of recv(char*, size_t): bytes received, 0 when the peer has
        // shut down, or SOCKET_ERROR (in which case `result` is left unchanged).
        int Socket::recv(string & result, size_t l) {
            std::unique_ptr<char[]> buff(new char[l]);
            const int rl = recv(buff.get(), l);

            // Append the whole chunk at once rather than one push_back per character.
            if (rl > 0) {
                result.append(buff.get(), static_cast<size_t>(rl));
            }
            return rl;
        }

        // Reads 512-byte chunks into `result` until a read returns fewer bytes than
        // requested (short read, peer shutdown, or error).
        // Returns the total number of bytes appended. If the very first read fails, that
        // read's error value (SOCKET_ERROR) is returned instead.
        // NOTE(review): on a blocking socket, if the pending data is an exact multiple of
        // the chunk size, the final recv blocks until more data or a shutdown arrives.
        int Socket::recvAll(vector<char> & result) {
            // Chunk size is passed explicitly so the loop condition and the request size
            // cannot drift apart (previously it relied on recv's default argument).
            constexpr int buffLength = 512;
            int rl = 0;
            int tl = 0;

            do {
                tl = recv(result, buffLength);
                if (tl < 0) {
                    // Don't fold the -1 error code into the byte count.
                    return rl > 0 ? rl : tl;
                }
                rl += tl;
            } while (tl == buffLength);

            return rl;
        }

        // Reads 1024-byte chunks into `result` until a read returns fewer bytes than
        // requested (short read, peer shutdown, or error).
        // Returns the total number of bytes appended. If the very first read fails, that
        // read's error value (SOCKET_ERROR) is returned instead.
        // NOTE(review): on a blocking socket, if the pending data is an exact multiple of
        // the chunk size, the final recv blocks until more data or a shutdown arrives.
        int Socket::recvAll(string & result) {
            constexpr int buffLength = 1024;
            int rl = 0;
            int tl = 0;

            do {
                tl = recv(result, buffLength);
                if (tl < 0) {
                    // Don't fold the -1 error code into the byte count.
                    return rl > 0 ? rl : tl;
                }
                rl += tl;
            } while (tl == buffLength);

            return rl;
        }

        // Receives up to `l` bytes into `buff` with ::recvfrom and reports the sender.
        // `ip` receives the sender's dotted-decimal IPv4 address (via inet_ntoa) and
        // `port` its port in host byte order. Both are written even when ::recvfrom fails;
        // because `addr` is zeroed first, an untouched address reads back as 0.0.0.0:0.
        // Returns the raw ::recvfrom result (bytes received or SOCKET_ERROR).
        int Socket::recvFrom(string & ip, int & port, char *buff, size_t l) {
            sockaddr_in addr;
            int al = sizeof(addr);

            memset(&addr, 0, al);
            const int ol = ::recvfrom(sp, buff, static_cast<int>(l), 0, (sockaddr *)&addr, &al);
            ip = inet_ntoa(addr.sin_addr);
            port = ntohs(addr.sin_port);

            return ol;
        }
        
        // Receives up to `l` bytes from any sender, appends them to `s`, and reports the
        // sender through `ip`/`port` (see the char* overload).
        // Returns the received byte count, or SOCKET_ERROR with `s` left unchanged.
        int Socket::recvFrom(string & ip, int & port, string & s, size_t l) {
            std::unique_ptr<char[]> buff(new char[l]);

            const int sl = recvFrom(ip, port, buff.get(), l);
            // Append the whole datagram at once rather than character by character.
            if (sl > 0) {
                s.append(buff.get(), static_cast<size_t>(sl));
            }
            return sl;
        }
        
        // Receives up to `l` bytes from any sender, appends them to `v`, and reports the
        // sender through `ip`/`port` (see the char* overload).
        // Returns the received byte count, or SOCKET_ERROR with `v` left unchanged.
        int Socket::recvFrom(string & ip, int & port, vector<char> & v, size_t l) {
            std::unique_ptr<char[]> buff(new char[l]);

            const int sl = recvFrom(ip, port, buff.get(), l);
            // One range insert instead of a push_back per byte.
            if (sl > 0) {
                v.insert(v.end(), buff.get(), buff.get() + sl);
            }
            return sl;
        }

        // Reports whether `sp` refers to a live socket by querying SO_PROTOCOL_INFOA;
        // any getsockopt failure (e.g. an invalid or already-closed handle) yields false.
        bool Socket::isUsable() {
            WSAPROTOCOL_INFOA protocolInfo;
            // optlen is an in/out parameter: getsockopt writes the returned length back
            // through it, so it must be a mutable int. The previous `constexpr int` cast
            // to int* made that write undefined behaviour.
            int l = sizeof(protocolInfo);
            return ::getsockopt(sp, SOL_SOCKET, SO_PROTOCOL_INFOA, (char*)&protocolInfo, &l) == 0;
        }

        // Returns the socket's address family (e.g. AF_INET) from its protocol info,
        // or -1 if the handle cannot be queried.
        int Socket::getFamily() {
            WSAPROTOCOL_INFOA protocolInfo;
            // In/out length for getsockopt; must be writable (a constexpr here was UB).
            int l = sizeof(protocolInfo);
            if (::getsockopt(sp, SOL_SOCKET, SO_PROTOCOL_INFOA, (char*)&protocolInfo, &l) == 0) {
                return protocolInfo.iAddressFamily;
            }
            return -1;
        }
        
        int Socket::getType() {
            WSAPROTOCOL_INFOA protocolInfo;
            int l = sizeof(protocolInfo);
            if (::getsockopt(sp, SOL_SOCKET, SO_PROTOCOL_INFOA, (char*)&protocolInfo, &l) == 0) {
                return protocolInfo.iSocketType;
            }
            return -1;
        }
        
        int Socket::getProtocol() {
            WSAPROTOCOL_INFOA protocolInfo;
            int l = sizeof(protocolInfo);
            if (::getsockopt(sp, SOL_SOCKET, SO_PROTOCOL_INFOA, (char*)&protocolInfo, &l) == 0) {
                return protocolInfo.iProtocol;
            }
            return -1;
        }

        // int Socket::getDataLength() {
        //     unsigned long l;
        //     if (::ioctlsocket(sp, FIONREAD, &l) != 0) {
        //         int error = WSAGetLastError();
        //         printf("Error:%d\n", error);
        //         return -1;
        //     }
        //     return l;
        // }

        // Factory for a fresh, unconnected socket built by Socket's default constructor
        // (used throughout this file as the TCP socket).
        Socket newTCPSocket() {
            return Socket{};
        }

        // Creates a TCP socket and connects it to sip:sport.
        // On success, returns a Socket owning the connected handle. On failure, returns a
        // Socket wrapping INVALID_SOCKET so that isUsable() on the result is false; the
        // unconnected socket is closed when `s` is destroyed.
        Socket newTCPClient(string sip, int sport) {
            Socket s = Socket();
            socket_t sp = s.getPtr();

            if (!s.connect(sip, sport)) {
                // Returning Socket() here would hand back a brand-new (usable) TCP socket,
                // making a failed connect indistinguishable from a successful one.
                return Socket(INVALID_SOCKET);
            }

            // Transfer ownership to the returned object: clearing s's handle keeps ~Socket
            // from closing it (assumes handle 0 never names a live socket).
            s.setPtr(0);
            return Socket(sp);
        }

        // Creates a TCP socket, binds it to sip:sport (empty `sip` = any local address)
        // and starts listening with backlog `la`.
        // On success, returns a Socket owning the listening handle. On failure, returns a
        // Socket wrapping INVALID_SOCKET (isUsable() is false); the partially set-up socket
        // is closed when `s` is destroyed.
        // NOTE(review): "Sever" is a typo for "Server", kept because callers use this name.
        Socket newTCPSever(string sip, int sport, int la) {
            Socket s = newTCPSocket();
            socket_t sp = s.getPtr();

            // Explicit invalid handle instead of the C-style `(Socket)0`, which can bind to
            // the Socket(int family, ...) constructor and create a new socket.
            if (!s.bind(sip, sport)) {
                return Socket(INVALID_SOCKET);
            }

            if (!s.listen(la)) {
                return Socket(INVALID_SOCKET);
            }

            // Transfer ownership to the returned object: clearing s's handle keeps ~Socket
            // from closing it (assumes handle 0 never names a live socket).
            s.setPtr(0);
            return Socket(sp);
        }
        
        // Factory for an unbound IPv4 UDP socket (AF_INET / SOCK_DGRAM / IPPROTO_UDP).
        Socket newUDPSocket() {
            return Socket{AF_INET, SOCK_DGRAM, IPPROTO_UDP};
        }
    }
}

